// Copyright (c) 2023, Arm Limited and Contributors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#define PW_LOG_MODULE_NAME "server"
#define PW_LOG_LEVEL       PW_LOG_LEVEL_DEBUG

#include <algorithm>
#include <array>
#include <cinttypes>
#include <cstdint>
#include <cstring>

#include "iot_socket.h"

#include "iotsdk/ip_network_api.h"

#include "iot-socket.pwpb.h"
#include "iot-socket.rpc.pwpb.h"
#include "pw_hdlc/rpc_channel.h"
#include "pw_hdlc/rpc_packets.h"
#include "pw_log/log.h"
#include "pw_rpc/channel.h"
#include "pw_rpc/server.h"
#include "pw_stream/sys_io_stream.h"
#include "pw_string/util.h"

namespace {
// Alias requests & responses into the namespace.
using CreateRequest = iotsdk::socket::pw_rpc::CreateRequest::Message;
using BindRequest = iotsdk::socket::pw_rpc::BindRequest::Message;
using ListenRequest = iotsdk::socket::pw_rpc::ListenRequest::Message;
using AcceptRequest = iotsdk::socket::pw_rpc::AcceptRequest::Message;
using ConnectRequest = iotsdk::socket::pw_rpc::ConnectRequest::Message;
using RecvRequest = iotsdk::socket::pw_rpc::RecvRequest::Message;
using SendRequest = iotsdk::socket::pw_rpc::SendRequest::Message;
using CloseRequest = iotsdk::socket::pw_rpc::CloseRequest::Message;
using GetLocalIpRequest = iotsdk::socket::pw_rpc::GetLocalIpRequest::Message;
using AcceptResponse = iotsdk::socket::pw_rpc::AcceptResponse::Message;
using RecvResponse = iotsdk::socket::pw_rpc::RecvResponse::Message;
using GenericResponse = iotsdk::socket::pw_rpc::GenericResponse::Message;
using GetLocalIpResponse = iotsdk::socket::pw_rpc::GetLocalIpResponse::Message;

// Alias in Service base class.
template <class T> using ServiceBase = iotsdk::socket::pw_rpc::pw_rpc::pwpb::IotSocketService::Service<T>;

// Packet decoding buffer size
constexpr size_t kBufferSize = 2048;

// Max recv buffer size
constexpr size_t kRecvBufferSize = 512; // set according to proto options file

// Instantiate server with one channel using pw_sys_io for I/O.
pw::stream::SysIoWriter writer;
pw::hdlc::RpcChannelOutput hdlc_channel_output(writer, pw::hdlc::kDefaultRpcAddress, "HDLC output");
std::array<pw::rpc::Channel, 1> channels{pw::rpc::Channel::Create<1>(&hdlc_channel_output)};
pw::rpc::Server server(channels);

// Used by Service::GetLocalIp and set_ip_address
bool local_ip_is_set = false;
ip_address_t local_ip;

// Service implementation.
//
// Each handler forwards its request to the matching iotSocket* call and copies
// that call's return code into response.status; negative codes are logged as
// warnings. A handler returns pw::OkStatus() once the socket call has been
// made, and pw::Status::InvalidArgument() (without calling into the socket
// layer) when a client-supplied length field is inconsistent with the buffer
// it describes.
struct Service : ServiceBase<Service> {
    // Create a socket. On success response.status holds the new socket id.
    pw::Status Create(const CreateRequest &request, GenericResponse &response)
    {
        response.status = iotSocketCreate(request.af, request.type, request.protocol);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketCreate: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " created", response.status);
        }

        return pw::OkStatus();
    }

    // Bind request.socket to the first request.ip_len bytes of request.ip and
    // to request.port.
    pw::Status Bind(const BindRequest &request, GenericResponse &response)
    {
        // ip_len comes from the client: never let iotSocketBind read past the
        // address bytes that were actually decoded into request.ip.
        if (request.ip_len > request.ip.size()) {
            PW_LOG_WARN("Bind: ip_len %" PRIu32 " exceeds supplied address", static_cast<uint32_t>(request.ip_len));
            return pw::Status::InvalidArgument();
        }

        response.status = iotSocketBind(
            request.socket, reinterpret_cast<const uint8_t *>(request.ip.data()), request.ip_len, request.port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketBind: %" PRId32, response.status);
        } else {
            // response.status is just the success code here; log the socket id.
            PW_LOG_DEBUG("Socket %" PRId32 " bound", request.socket);
        }

        return pw::OkStatus();
    }

    // Put request.socket into listening mode with the given backlog.
    pw::Status Listen(const ListenRequest &request, GenericResponse &response)
    {
        response.status = iotSocketListen(request.socket, request.backlog);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketListen: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " listening", request.socket);
        }

        return pw::OkStatus();
    }

    // Accept a connection on request.socket. On success response.status holds
    // the accepted socket id, response.ip/ip_len the peer address and
    // response.port the peer port.
    pw::Status Accept(const AcceptRequest &request, AcceptResponse &response)
    {
        uint16_t port = 0;

        // Per the iot_socket API, ip_len is the size of the supplied ip buffer
        // on input and the stored address length on output. Never advertise
        // more room than response.ip really has.
        response.ip.resize(response.ip.max_size());
        response.ip_len = static_cast<uint32_t>(std::min<size_t>(sizeof(iot_in6_addr), response.ip.size()));
        response.status =
            iotSocketAccept(request.socket, reinterpret_cast<uint8_t *>(response.ip.data()), &response.ip_len, &port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketAccept: %" PRId32, response.status);
        } else {
            response.port = static_cast<uint32_t>(port);
            PW_LOG_DEBUG("Socket %" PRId32 " accepted connection", response.status);
        }

        return pw::OkStatus();
    }

    // Connect request.socket to the first request.ip_len bytes of request.ip
    // and to request.port.
    pw::Status Connect(const ConnectRequest &request, GenericResponse &response)
    {
        PW_LOG_INFO("Connecting");

        // ip_len comes from the client: never let iotSocketConnect read past
        // the address bytes that were actually decoded into request.ip.
        if (request.ip_len > request.ip.size()) {
            PW_LOG_WARN("Connect: ip_len %" PRIu32 " exceeds supplied address", static_cast<uint32_t>(request.ip_len));
            return pw::Status::InvalidArgument();
        }

        response.status = iotSocketConnect(
            request.socket, reinterpret_cast<const uint8_t *>(request.ip.data()), request.ip_len, request.port);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketConnect: %" PRId32, response.status);
        } else {
            // response.status is just the success code here; log the socket id.
            PW_LOG_DEBUG("Socket %" PRId32 " connected", request.socket);
        }

        return pw::OkStatus();
    }

    // Receive up to request.len bytes from request.socket into response.buf.
    // On success response.buf holds exactly the bytes received; on error it
    // is empty.
    pw::Status Recv(const RecvRequest &request, RecvResponse &response)
    {
        // Respect both the configured limit and the buffer's real capacity.
        const size_t max_len = std::min<size_t>(kRecvBufferSize, response.buf.max_size());
        if (request.len > max_len) {
            PW_LOG_WARN("Recv: len %" PRIu32 " exceeds buffer size", static_cast<uint32_t>(request.len));
            return pw::Status::InvalidArgument();
        }
        response.buf.resize(request.len);

        response.status = iotSocketRecv(request.socket, response.buf.data(), request.len);

        if (response.status < 0) {
            // Don't send request.len bytes of padding back alongside an error.
            response.buf.clear();
            PW_LOG_WARN("iotSocketRecv: %" PRId32, response.status);
        } else {
            response.buf.resize(response.status);
            PW_LOG_DEBUG("Socket %" PRId32 " read %" PRId32 " data", request.socket, response.status);
        }

        return pw::OkStatus();
    }

    // Send the first request.len bytes of request.buf on request.socket.
    pw::Status Send(const SendRequest &request, GenericResponse &response)
    {
        // Only bytes the client actually supplied may be sent; anything past
        // buf.size() is not part of the decoded payload.
        if (request.len > request.buf.size()) {
            PW_LOG_WARN("Send: len %" PRIu32 " exceeds supplied data", static_cast<uint32_t>(request.len));
            return pw::Status::InvalidArgument();
        }

        response.status = iotSocketSend(request.socket, request.buf.data(), request.len);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketSend: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " sent %" PRId32 " data", request.socket, response.status);
        }

        return pw::OkStatus();
    }

    // Close request.socket.
    pw::Status Close(const CloseRequest &request, GenericResponse &response)
    {
        response.status = iotSocketClose(request.socket);

        if (response.status < 0) {
            PW_LOG_WARN("iotSocketClose: %" PRId32, response.status);
        } else {
            PW_LOG_DEBUG("Socket %" PRId32 " closed", request.socket);
        }

        return pw::OkStatus();
    }

    // Report the address recorded by set_ip_address(). response.status is -1
    // when no address has been recorded (or it does not fit in response.ip)
    // and 0 otherwise. On success response.ip is zero-padded to its maximum
    // size and response.ip_len is the length of the stored address.
    pw::Status GetLocalIp(const GetLocalIpRequest & /* request */, GetLocalIpResponse &response)
    {
        if (!local_ip_is_set) {
            response.status = -1;
            return pw::OkStatus();
        }

        // Any address type other than IP_V4_ADDR is treated as IPv6.
        const void *addr = nullptr;
        size_t addr_len = 0;
        if (local_ip.type == IP_V4_ADDR) {
            addr = &local_ip.addr.ipv4;
            addr_len = sizeof(local_ip.addr.ipv4);
        } else {
            addr = local_ip.addr.ipv6;
            addr_len = sizeof(local_ip.addr.ipv6);
        }

        // Guard the memcpy below against a proto ip field smaller than the address.
        if (addr_len > response.ip.max_size()) {
            PW_LOG_WARN("GetLocalIp: address does not fit in response");
            response.status = -1;
            return pw::OkStatus();
        }

        response.status = 0;

        response.ip.resize(response.ip.max_size());
        memset(response.ip.data(), 0, response.ip.max_size());

        response.ip_len = static_cast<uint32_t>(addr_len);
        memcpy(response.ip.data(), addr, addr_len);

        return pw::OkStatus();
    }
} service;
} // namespace

// Set the local IP address when network is brought up
void set_ip_address(const ip_address_t *ip)
{
    memcpy(&local_ip, ip, sizeof(local_ip));
    local_ip_is_set = true;
}

// Entry point of the RPC application: registers the socket service with the
// server, then reads and dispatches HDLC-framed RPC packets forever.
void start_rpc_app()
{
    PW_LOG_INFO("Starting server");

    server.RegisterService(service);

    // Scratch space for decoding incoming packets. It has static storage
    // duration, so it is not placed on this function's stack.
    static std::array<std::byte, kBufferSize> decode_buffer;

    // Never returns: processes packets forever.
    pw::hdlc::ReadAndProcessPackets(server, hdlc_channel_output, decode_buffer);
}
